import torch
import torch.nn as nn
import torch.nn.functional as F


class Critic(nn.Module):
    """Recurrent Q-network that scores discrete actions and selects one.

    The observation encoder depends on ``args.scenario_name``:

    * ``'GuessingNumber'``: two fully connected layers over an observation
      whose last dimension is ``args.obs_shape``.
    * ``'RevealingGoal'``: two 3x3 convolutions (16 channels, stride 1,
      padding 1, so the spatial size is preserved) over an observation of
      shape ``(args.obs_shape[0], args.obs_shape[1], args.obs_shape[2])``.
      The result is flattened and projected to ``args.hidden_size``.

    Encoded features then pass through a GRU with ``args.rnn_layer`` layers
    and a linear head that outputs ``args.action_shape`` values per row.

    Any other scenario name raises ``NotImplementedError``.
    """

    def __init__(self, args):
        super().__init__()
        # NOTE: keep the module creation order and Sequential indices as-is.
        # Parameter initialisation draws from the global RNG in this order,
        # and checkpoints rely on keys such as ``fc1.0.weight`` / ``fc1.2.weight``.
        if args.scenario_name in ['GuessingNumber']:
            self.fc1 = nn.Linear(args.obs_shape, args.hidden_size)
            self.fc2 = nn.Linear(args.hidden_size, args.hidden_size)
        elif args.scenario_name == 'RevealingGoal':
            conv_layers = [
                nn.Conv2d(in_channels=args.obs_shape[0], out_channels=16, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
                nn.Flatten(),
            ]
            self.fc1 = nn.Sequential(*conv_layers)
            # Padding 1 keeps H and W unchanged, so the flattened size is 16 * H * W.
            self.fc2 = nn.Linear(16 * args.obs_shape[1] * args.obs_shape[2], args.hidden_size)
        else:
            raise NotImplementedError
        # nn.GRU(input_size, hidden_size, num_layers); batch_first is False,
        # so it consumes (seq_len, batch, feature) tensors.
        self.rnn = nn.GRU(args.hidden_size, args.hidden_size, args.rnn_layer)
        self.q_out = nn.Linear(args.hidden_size, args.action_shape)

        self.hidden_size = args.hidden_size
        self.args = args

    def forward(self, state, a_u, hidden, batch_size=0, eps=0, select_id=None):
        """Compute per-action values, then pick an action or read out given ones.

        Args:
            state: Observation batch, one row per (time step, sequence) pair.
                For ``'RevealingGoal'`` it is reshaped to
                ``(state.shape[0], obs_shape[0], obs_shape[1], obs_shape[2])``.
            a_u: Per-action weights with the same shape as the Q-values. They
                are multiplied into the shifted Q-values for the greedy choice.
                They also serve as ``torch.multinomial`` sampling weights when
                exploring, so they must be non-negative floats (presumably a
                0/1 legal-action mask; confirm against callers).
            hidden: Initial GRU hidden state of shape
                ``(args.rnn_layer, batch, hidden_size)``, per ``nn.GRU``.
            batch_size: ``0`` runs one GRU step with every row of ``state`` as
                its own batch entry. Otherwise rows are reshaped to
                ``(-1, batch_size, hidden_size)``, i.e. treated time-major:
                row ``t * batch_size + b`` is step ``t`` of sequence ``b``.
            eps: Each row independently replaces its greedy action with one
                sampled from ``a_u`` with probability ``eps``. ``0`` disables
                exploration.
            select_id: Optional long tensor of action indices, gathered along
                the last dimension. When given, action selection is skipped.

        Returns:
            ``(q_value.gather(-1, select_id), hidden)`` when ``select_id`` is
            given. Otherwise ``(action_id, selected_logits, hidden)``, where
            ``action_id`` is a long tensor of chosen actions and
            ``selected_logits`` holds their Q-values with a trailing dim of 1.

        Raises:
            NotImplementedError: if ``args.scenario_name`` is not supported.
        """
        if self.args.scenario_name in ['GuessingNumber']:
            pass
        elif self.args.scenario_name == 'RevealingGoal':
            bs = state.shape[0]
            state = state.reshape(bs, self.args.obs_shape[0], self.args.obs_shape[1], self.args.obs_shape[2])
        else:
            raise NotImplementedError

        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        if batch_size == 0:
            # Single step: add a sequence dimension of length 1 for the GRU.
            x, hidden = self.rnn(x.unsqueeze(0), hidden)
            x = x.squeeze(0)
        else:
            x = x.reshape(-1, batch_size, self.hidden_size)
            x, hidden = self.rnn(x, hidden)
            x = x.reshape(-1, self.hidden_size)

        q_value = self.q_out(x)

        if select_id is not None:
            return q_value.gather(-1, select_id), hidden

        # Shift every value to be >= 1, so that entries zeroed by a_u can never
        # win the argmax over entries with a positive weight. A shift that is
        # constant within each row does not change the argmax. No gradient is
        # needed because argmax is not differentiable.
        with torch.no_grad():
            legal_adv = (1 + q_value - q_value.min()) * a_u
        greedy_action_id = legal_adv.argmax(dim=-1)

        if eps > 0:
            # Draw the random action before the coin flip to keep the RNG
            # consumption order unchanged.
            random_action = a_u.multinomial(1).squeeze(1)
            explore = torch.rand(greedy_action_id.size(), device=greedy_action_id.device) < eps
            action_id = torch.where(explore, random_action, greedy_action_id)
        else:
            action_id = greedy_action_id
        action_id = action_id.detach().long()

        selected_logits = q_value.gather(-1, action_id.unsqueeze(-1))

        return action_id, selected_logits, hidden
